{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluating Detectors\n",
    "\n",
    "In `scikit-clean`, a `Detector` only identifies/detects the mislabelled samples. It's not a complete classifier (rather, a part of one), so the procedure for evaluating detectors is different.\n",
    "\n",
    "We can view a noise detector as a binary classifier: its job is to provide a probability denoting whether a sample is \"mislabelled\" or \"clean\". We can therefore use binary classification metrics that work on continuous output: Brier score, log loss, area under the ROC curve, etc."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Suppress warnings, you should remove this before modifying this notebook\n",
    "def warn(*args, **kwargs):\n",
    "    pass\n",
    "import warnings\n",
    "warnings.warn = warn\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from sklearn.datasets import make_classification\n",
    "from sklearn.metrics import brier_score_loss, log_loss, roc_auc_score\n",
    "\n",
    "from skclean.tests.common_stuff import NOISE_DETECTORS  # All noise detectors in skclean\n",
    "from skclean.utils import load_data \n",
    "from skclean.detectors.base import BaseDetector\n",
    "from skclean.simulate_noise import flip_labels_uniform"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "class DummyDetector(BaseDetector):\n",
    "    \"\"\"Baseline detector that ignores the data and scores samples at random.\"\"\"\n",
    "\n",
    "    def detect(self, X, y):\n",
    "        \"\"\"Return uniform random scores in [0, 1) with the same shape as `y`; `X` is unused.\"\"\"\n",
    "        score_shape = y.shape\n",
    "        return np.random.uniform(low=0.0, high=1.0, size=score_shape)\n",
    "\n",
    "from skclean.detectors import KDN, RkDN\n",
    "class WkDN(BaseDetector):\n",
    "    \"\"\"Blend of two detectors: the equal-weight average of `KDN` and `RkDN` scores.\n",
    "\n",
    "    Subclasses `BaseDetector` for consistency with `DummyDetector`.\n",
    "    \"\"\"\n",
    "\n",
    "    def detect(self, X, y):\n",
    "        \"\"\"Return `0.5 * KDN score + 0.5 * RkDN score` for the samples `(X, y)`.\n",
    "\n",
    "        Both underlying detectors are created with their default settings on every call.\n",
    "        \"\"\"\n",
    "        kdn_score = KDN().detect(X, y)\n",
    "        rkdn_score = RkDN().detect(X, y)\n",
    "        return .5 * kdn_score + .5 * rkdn_score\n",
    "    \n",
    "# Detectors to evaluate: the random baseline, the KDN/RkDN blend, then the detectors from NOISE_DETECTORS.\n",
    "ALL_DETECTORS = [DummyDetector(), WkDN()] + NOISE_DETECTORS\n",
    "ALL_DETECTOTS = ALL_DETECTORS  # misspelled original name, kept as an alias because the evaluation cell uses it"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "SEED = 42\n",
    "# Seed NumPy's global RNG so reruns give the same numbers.\n",
    "# NOTE(review): flip_labels_uniform presumably draws from the global RNG when no random_state is passed - confirm.\n",
    "np.random.seed(SEED)\n",
    "\n",
    "X, y = make_classification(n_samples=1800, n_features=10, random_state=SEED)\n",
    "#X, y = load_data('breast_cancer')\n",
    "\n",
    "yn = flip_labels_uniform(y, .3)  # 30% label noise\n",
    "clean_idx = (y==yn)              # Boolean mask: True where the noisy label still equals the true label"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>log</th>\n",
       "      <th>brier</th>\n",
       "      <th>roc</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>DummyDetector</th>\n",
       "      <td>0.999</td>\n",
       "      <td>0.333</td>\n",
       "      <td>0.501</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>WkDN</th>\n",
       "      <td>0.664</td>\n",
       "      <td>0.183</td>\n",
       "      <td>0.811</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>ForestKDN</th>\n",
       "      <td>1.099</td>\n",
       "      <td>0.131</td>\n",
       "      <td>0.858</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>InstanceHardness</th>\n",
       "      <td>0.448</td>\n",
       "      <td>0.141</td>\n",
       "      <td>0.902</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>KDN</th>\n",
       "      <td>0.830</td>\n",
       "      <td>0.173</td>\n",
       "      <td>0.818</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>RkDN</th>\n",
       "      <td>3.371</td>\n",
       "      <td>0.227</td>\n",
       "      <td>0.749</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>MCS</th>\n",
       "      <td>0.294</td>\n",
       "      <td>0.071</td>\n",
       "      <td>0.955</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>PartitioningDetector</th>\n",
       "      <td>0.942</td>\n",
       "      <td>0.072</td>\n",
       "      <td>0.950</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>RandomForestDetector</th>\n",
       "      <td>0.464</td>\n",
       "      <td>0.145</td>\n",
       "      <td>0.908</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                        log  brier    roc\n",
       "DummyDetector         0.999  0.333  0.501\n",
       "WkDN                  0.664  0.183  0.811\n",
       "ForestKDN             1.099  0.131  0.858\n",
       "InstanceHardness      0.448  0.141  0.902\n",
       "KDN                   0.830  0.173  0.818\n",
       "RkDN                  3.371  0.227  0.749\n",
       "MCS                   0.294  0.071  0.955\n",
       "PartitioningDetector  0.942  0.072  0.950\n",
       "RandomForestDetector  0.464  0.145  0.908"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Metric name -> scoring function; each is called as metric(y_true, y_score).\n",
    "METRICS = {'log': log_loss, 'brier': brier_score_loss, 'roc': roc_auc_score}\n",
    "\n",
    "# One row of rounded scores per detector, keyed by the detector's class name.\n",
    "scores = {}\n",
    "for detector in ALL_DETECTOTS:\n",
    "    # The detector only sees the noisy labels; its output is scored against the clean mask.\n",
    "    conf_score = detector.detect(X, yn)\n",
    "    scores[type(detector).__name__] = {\n",
    "        name: np.round(metric(clean_idx, conf_score), 3)\n",
    "        for name, metric in METRICS.items()\n",
    "    }\n",
    "\n",
    "# Build the table once instead of enlarging an empty frame cell by cell.\n",
    "df = pd.DataFrame(list(scores.values()), index=list(scores), columns=list(METRICS))\n",
    "df"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note that for `log_loss` and `brier_score_loss` lower is better, whereas for `roc_auc_score` higher is better."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
